{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e63a29eb-cc7d-4956-b06d-1a3cdc47e16c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Video processing with object detection using batch inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14d1d1fb-beaa-467c-8654-9d69c2f2193f",
   "metadata": {
    "tags": []
   },
   "source": [
    "This tutorial uses Ray and Anyscale for distributed data processing, PyTorch with a pre-trained Faster R-CNN model for object detection, and several other Python libraries for image and video handling. It shows how to:\n",
    "\n",
    "* Load a video from S3.\n",
    "* Split the video into individual frames.\n",
    "* Apply an object detection model to detect masks.\n",
    "* Draw bounding boxes and labels on each frame.\n",
    "* Generate a new video from the processed frames.\n",
    "\n",
    "This approach is very similar to the evaluation-focused pipeline in the previous notebook, in which it leverages batch inference with Ray Data, but unlike the previous notebook, this tutorial is purely inference—without computing metrics like mAP or IoU. Instead, it represents a real-world video analytics workflow, suitable for deployment in production environments.\n",
    "\n",
    "Here is the architecture diagram illustrates a distributed video processing pipeline using Ray Data batch inference on Anyscale for mask detection. \n",
    "\n",
    "<img\n",
    "  src=\"https://face-masks-data.s3.us-east-2.amazonaws.com/tutorial-diagrams/video_processing.png\"\n",
    "  alt=\"Object Detection Batch Inferece Pipeline - Video Processing\"\n",
    "  style=\"width:75%;\"\n",
    "/>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53a82f7e-e0b1-4768-a68e-8ddb1205e801",
   "metadata": {
    "tags": []
   },
   "source": [
    "<div class=\"alert alert-block alert-warning\">\n",
    "  <b>Anyscale-specific configuration</b>\n",
    "  \n",
    "  <p>Note: This tutorial is optimized for the Anyscale platform. When running on open source Ray, additional configuration is required. For example, you’ll need to manually:</p>\n",
    "  \n",
    "  <ul>\n",
    "    <li>\n",
    "      <b>Configure your Ray Cluster:</b> Set up your multi-node environment, including head and worker nodes, and manage resource allocation like autoscaling and GPU/CPU assignments, without the Anyscale automation. See <a href=\"https://docs.ray.io/en/latest/cluster/getting-started.html\">Ray Clusters</a> for details.\n",
    "    </li>\n",
    "    <li>\n",
    "      <b>Manage dependencies:</b> Install and manage dependencies on each node since you won’t have Anyscale’s Docker-based dependency management. See <a href=\"https://docs.ray.io/en/latest/ray-core/handling-dependencies.html\">Environment Dependencies</a> for instructions on installing and updating Ray in your environment.\n",
    "    </li>\n",
    "    <li>\n",
    "      <b>Set up storage:</b> Configure your own distributed or shared storage system (instead of relying on Anyscale’s integrated cluster storage). See <a href=\"https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html\">Configuring Persistent Storage</a> for suggestions on setting up shared storage solutions.\n",
    "    </li>\n",
    "  </ul>\n",
    "\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "653d2320",
   "metadata": {},
   "source": [
    "## Why use Ray and Anyscale for batch inference\n",
    "\n",
    "Batch inference with Ray and Anyscale is a more efficient way to handle large-scale inference tasks compared to the traditional method of serving and processing image requests one by one using APIs or endpoints. Instead of handling each request individually, batch inference processes multiple inputs simultaneously, leading to significant performance improvements. The key benefits include:\n",
    "\n",
    "* **Higher throughput**—Processing multiple images at once reduces the overhead of repeatedly loading the model and managing individual inference requests.\n",
    "* **Better resource utilization**—Ray uses GPUs and other hardware accelerators more efficiently when running inference in batches rather than performing single-image inferences, which can lead to underutilization.\n",
    "* **Lower latency for bulk processing**—While batch inference may introduce slight delays for individual requests, it significantly reduces the overall time required to process large datasets, making it ideal for offline or faster processing of videos.\n",
    "* **Scalability**—Batch inference with Ray allows distributed processing across multiple nodes, enabling efficient scaling for high-volume workloads.\n",
    "* **Automatic resource shutdown and cost efficiency**—Instead of keeping inference servers running continuously, once batch inference completes, Ray automatically shuts down idle resources, preventing unnecessary compute usage. You can also schedule batch processing during off-peak hours or using `spot instances`, leading to significant cost savings on compute resources.\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "937e436e-ed81-48da-a5ea-2cf2d24fd596",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Import libraries and define label mappings\n",
    "The first block of code imports all necessary libraries and sets up mappings for your classes like `with_mask`, `without_mask`, etc., and their corresponding colors for visualization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f24143e-f3cf-489b-a49c-e3cdf9299905",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import ray\n",
    "import numpy as np\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "from io import BytesIO\n",
    "import cv2\n",
    "import torch\n",
    "from torchvision import models\n",
    "import os\n",
    "from smart_open import open as smart_open\n",
    "import io\n",
    "import ray\n",
    "\n",
    "\n",
    "CLASS_TO_LABEL = {\n",
    "    \"background\": 0,\n",
    "    \"with_mask\": 1,\n",
    "    \"without_mask\": 2,\n",
    "    \"mask_weared_incorrect\": 3\n",
    "}\n",
    "LABEL_TO_CLASS = {v: k for k, v in CLASS_TO_LABEL.items()}\n",
    "LABEL_COLORS = {\n",
    "    \"with_mask\": \"green\",\n",
    "    \"without_mask\": \"red\",\n",
    "    \"mask_weared_incorrect\": \"yellow\"\n",
    "}\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ede51725-44c9-48d7-a061-28d27b79858b",
   "metadata": {},
   "source": [
    "## Load and split video into frames\n",
    "Load the video file from an S3 bucket using the Ray Data API. Then convert it into individual frames. Each frame is stored along with its frame number.\n",
    "\n",
    "The Dataset should have two columns `frame` and `frame_index`.\n",
    "\n",
    "\n",
    "Note that `ray.data.read_videos` can also process directories containing multiple videos. In this case, consider setting the `include_paths` parameter to `True` to store file paths in the path column. This setting helps track which video each frame originated from.\n",
    "\n",
    "For more details, see: https://docs.ray.io/en/latest/data/api/doc/ray.data.read_videos.html#ray.data.read_videos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a45534e9-5aa8-402b-a16f-a49161a6cfbe",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "ds_frames = ray.data.read_videos(\"s3://face-masks-data/videos/video1.mp4\")\n",
    "ds_frames.schema()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f8350b4",
   "metadata": {},
   "source": [
    "### Visualize some frames\n",
    "\n",
    "You can see there are in total 383 frames in the video. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ff46e46",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "\n",
    "# Convert to a Pandas DataFrame\n",
    "df = ds_frames.to_pandas()\n",
    "\n",
    "# Print the total number of frames\n",
    "print(\"Total number of frames:\", len(df))\n",
    "\n",
    "# Randomly sample 5 frames\n",
    "sampled_frames = df.sample(n=5, random_state=42).sort_values(by='frame_index')\n",
    "\n",
    "# Display sampled frames\n",
    "fig, axes = plt.subplots(1, 5, figsize=(20, 5))\n",
    "for i, (idx, row) in enumerate(sampled_frames.iterrows()):\n",
    "    frame_data = row[\"frame\"]\n",
    "    axes[i].imshow(frame_data)\n",
    "    axes[i].axis(\"off\")\n",
    "    axes[i].set_title(f\"Frame {row['frame_index']}\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3390203-f5ca-42f3-a6df-c3b5781c16da",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Load the object detection model\n",
    "Next, create a class that loads a pre-trained Faster R-CNN model from AWS S3 to the cluster storage and applies it to a batch of images. \n",
    "\n",
    "\n",
    "Define the `BatchObjectDetectionModel` class to encapsulate the detection logic, which you can later use with the `map_batches` function in Ray Data.\n",
    "\n",
    "Ray Data allows for two approaches when applying transformations like `map` or `map_batches`:\n",
    "\n",
    "* **Functions**: These use stateless Ray tasks, which are ideal for simple operations that don’t require loading heavyweight models.\n",
    "* **Classes**: These use stateful Ray actors, making them well-suited for more complex tasks involving heavyweight models—**exactly what you need in this case**.\n"
   ]
  },
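  {
   "cell_type": "markdown",
   "id": "b7e4c1a2-stateful-sketch-md",
   "metadata": {},
   "source": [
    "To illustrate the difference between the two approaches, the following sketch uses plain Python without Ray. `ExpensiveModel`, `stateless_transform`, and `StatefulTransform` are hypothetical names that aren't part of this tutorial or of Ray; the load counter stands in for a costly model load. The function pays the setup cost on every batch, while the class pays it once in `__init__` and reuses the state on every call, which is roughly how Ray Data treats a class passed to `map_batches`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e4c1a2-stateful-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: plain Python, no Ray required.\n",
    "class ExpensiveModel:\n",
    "    # Hypothetical stand-in for a heavyweight model; counts how often it loads.\n",
    "    load_count = 0\n",
    "\n",
    "    def __init__(self):\n",
    "        ExpensiveModel.load_count += 1\n",
    "\n",
    "    def predict(self, values):\n",
    "        return [v * 2 for v in values]\n",
    "\n",
    "\n",
    "def stateless_transform(batch: dict) -> dict:\n",
    "    # A function has no place to keep the model, so it reloads it per batch.\n",
    "    model = ExpensiveModel()\n",
    "    return {'out': model.predict(batch['x'])}\n",
    "\n",
    "\n",
    "class StatefulTransform:\n",
    "    def __init__(self):\n",
    "        # Runs once per instance (once per actor in Ray Data).\n",
    "        self.model = ExpensiveModel()\n",
    "\n",
    "    def __call__(self, batch: dict) -> dict:\n",
    "        return {'out': self.model.predict(batch['x'])}\n",
    "\n",
    "\n",
    "batches = [{'x': [1, 2]}, {'x': [3, 4]}, {'x': [5, 6]}]\n",
    "\n",
    "for b in batches:\n",
    "    stateless_transform(b)\n",
    "loads_stateless = ExpensiveModel.load_count  # One load per batch.\n",
    "\n",
    "ExpensiveModel.load_count = 0\n",
    "transform = StatefulTransform()\n",
    "results = [transform(b) for b in batches]\n",
    "loads_stateful = ExpensiveModel.load_count  # A single load in total.\n",
    "\n",
    "print(loads_stateless, loads_stateful)"
   ]
  },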
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d2fc2fa-7374-43e1-80fb-6bb62ab12b96",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "# Paths.\n",
    "remote_model_path = \"s3://face-masks-data/finetuned-models/fasterrcnn_model_mask_detection.pth\"\n",
    "cluster_model_path = \"/mnt/cluster_storage/fasterrcnn_model_mask_detection.pth\"  \n",
    "\n",
    "# Download model only once.\n",
    "if not os.path.exists(cluster_model_path):\n",
    "    with smart_open(remote_model_path, 'rb') as s3_file:\n",
    "        with open(cluster_model_path, 'wb') as local_file:\n",
    "            local_file.write(s3_file.read())\n",
    "\n",
    "# Load the model (driver verifies it works).\n",
    "loaded_model = models.detection.fasterrcnn_resnet50_fpn(num_classes=len(CLASS_TO_LABEL))\n",
    "loaded_model.load_state_dict(torch.load(cluster_model_path, map_location=\"cpu\"))\n",
    "loaded_model.eval()\n",
    "\n",
    "\n",
    "class BatchObjectDetectionModel:\n",
    "    def __init__(self):\n",
    "        self.model = loaded_model\n",
    "        if torch.cuda.is_available():\n",
    "            self.model = self.model.cuda()\n",
    "\n",
    "    def __call__(self, batch: dict) -> dict:\n",
    "        predictions = []\n",
    "        for image_np in batch[\"frame\"]:\n",
    "            image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).float() / 255.0\n",
    "            if torch.cuda.is_available():\n",
    "                image_tensor = image_tensor.cuda()\n",
    "            with torch.no_grad():\n",
    "                pred = self.model([image_tensor])[0]\n",
    "            predictions.append({\n",
    "                \"boxes\": pred[\"boxes\"].detach().cpu().numpy(),\n",
    "                \"labels\": pred[\"labels\"].detach().cpu().numpy(),\n",
    "                \"scores\": pred[\"scores\"].detach().cpu().numpy()\n",
    "            })\n",
    "        batch[\"predictions\"] = predictions\n",
    "        return batch\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d99f41d-74af-4a04-ae7e-4f2444073f71",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Apply the object detection model\n",
    "\n",
    "Apply the BatchObjectDetectionModel to each batch of frames using the Ray Data `map_batches` method. This step performs object detection on all frames in parallel.\n",
    "\n",
    "### Understand key parameters\n",
    "* `concurrency`: Defines the number of parallel workers processing batches. Increasing this value enables more workers to process data simultaneously, speeding up computation but requiring more system resources (CPU, memory, and GPUs).\n",
    "* `batch_size`: Specifies how many frames each worker processes at a time. A larger batch size increases throughput but may require more GPU memory. Finding the optimal batch size depends on the available memory of your GPUs.\n",
    "* `num_gpus`: Sets the number of GPUs each worker can use. In this case, you allocate 1 GPU to each worker, meaning the total number of GPUs used is `concurrency` * `num_gpus`.\n",
    "\n",
    "### Adjust for performance:\n",
    "* If your system has more GPUs, you can increase concurrency to use more parallel workers.\n",
    "* If you have limited GPU memory, try reducing `batch_size` to avoid memory overflow.\n",
    "\n",
    "\n",
    "For more information, see: https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map_batches.html"
   ]
  },
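  {
   "cell_type": "markdown",
   "id": "c3f9d5e6-batch-plan-sketch-md",
   "metadata": {},
   "source": [
    "As a rough back-of-the-envelope sketch of how these parameters interact, the following cell computes the number of batches and the total GPU count for the values used in this tutorial (383 frames, `batch_size=8`, `concurrency=2`, `num_gpus=1`). The helper name `plan_batches` is hypothetical and isn't part of Ray Data. The even split of batches across workers is a simplifying assumption; Ray Data schedules work dynamically, and its internal blocks may not align exactly with `batch_size`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3f9d5e6-batch-plan-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def plan_batches(num_frames, batch_size, concurrency, num_gpus):\n",
    "    # Hypothetical helper for illustration; not part of Ray Data.\n",
    "    num_batches = math.ceil(num_frames / batch_size)\n",
    "    return {\n",
    "        'num_batches': num_batches,\n",
    "        'total_gpus': concurrency * num_gpus,\n",
    "        # Assumes an even split, which real scheduling doesn't guarantee.\n",
    "        'batches_per_worker': math.ceil(num_batches / concurrency),\n",
    "    }\n",
    "\n",
    "\n",
    "plan = plan_batches(num_frames=383, batch_size=8, concurrency=2, num_gpus=1)\n",
    "print(plan)"
   ]
  },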
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1d037d5-b48c-4eba-b3c3-07a6ca74a1fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Apply object detection model.\n",
    "ds_predicted = ds_frames.map_batches(\n",
    "    BatchObjectDetectionModel, \n",
    "    concurrency=2,   # Specify 2 workers.\n",
    "    batch_size=8,\n",
    "    num_gpus=1 # Each worker uses 1 GPU. In total Ray Data uses 2 GPUs.\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0d81761-2ee9-420d-acae-8ca1b418b392",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Draw bounding boxes and labels on each frame\n",
    "Next, define a function to draw bounding boxes and labels on the detected objects. This step uses the predictions from your model and the mappings you defined earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb9891b3-53e4-49f4-8da9-faa02f71c603",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "# Draw bounding boxes and labels on each frame.\n",
    "def draw_boxes(row):\n",
    "    image_np = row[\"frame\"]\n",
    "    predictions = row[\"predictions\"]\n",
    "    boxes = predictions[\"boxes\"]\n",
    "    labels = predictions[\"labels\"]\n",
    "    scores = predictions[\"scores\"]\n",
    "    \n",
    "    confidence_threshold = 0.5\n",
    "    valid = scores > confidence_threshold\n",
    "    boxes = boxes[valid]\n",
    "    labels = labels[valid]\n",
    "    scores = scores[valid]\n",
    "\n",
    "    pil_image = Image.fromarray(image_np)\n",
    "    draw = ImageDraw.Draw(pil_image)\n",
    "    font = ImageFont.load_default()\n",
    "\n",
    "    for box, label, score in zip(boxes, labels, scores):\n",
    "        x1, y1, x2, y2 = box\n",
    "        class_name = LABEL_TO_CLASS.get(label, \"unknown\")\n",
    "        color = LABEL_COLORS.get(class_name, \"white\")\n",
    "        \n",
    "        # Draw bounding box.\n",
    "        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)\n",
    "        \n",
    "        # Prepare text.\n",
    "        text = f\"{class_name} {score:.2f}\"\n",
    "        text_bbox = draw.textbbox((0, 0), text, font=font)\n",
    "        text_width = text_bbox[2] - text_bbox[0]\n",
    "        text_height = text_bbox[3] - text_bbox[1]\n",
    "        \n",
    "        # Draw text background.\n",
    "        draw.rectangle(\n",
    "            [x1, y1 - text_height - 2, x1 + text_width, y1],\n",
    "            fill=color\n",
    "        )\n",
    "        \n",
    "        # Draw text.\n",
    "        text_color = \"black\" if color == \"yellow\" else \"white\"\n",
    "        draw.text(\n",
    "            (x1, y1 - text_height - 2),\n",
    "            text,\n",
    "            fill=text_color,\n",
    "            font=font\n",
    "        )\n",
    "    \n",
    "    return {\n",
    "        \"frame\": np.array(pil_image),\n",
    "        \"frame_index\": row[\"frame_index\"]\n",
    "    }\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "ds_visualized = ds_predicted.map(draw_boxes)"
   ]
  },
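  {
   "cell_type": "markdown",
   "id": "d8a2f0b4-confidence-mask-sketch-md",
   "metadata": {},
   "source": [
    "The confidence filtering inside `draw_boxes` relies on NumPy boolean masking. The following minimal sketch isolates that step with made-up boxes, labels, and scores; the values are illustrative and aren't model output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8a2f0b4-confidence-mask-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up detections for illustration: three boxes as [x1, y1, x2, y2].\n",
    "boxes = np.array([[10, 10, 50, 50], [20, 30, 80, 90], [5, 5, 15, 15]], dtype=np.float32)\n",
    "labels = np.array([1, 2, 3])\n",
    "scores = np.array([0.92, 0.35, 0.61])\n",
    "\n",
    "confidence_threshold = 0.5\n",
    "valid = scores > confidence_threshold  # Boolean mask, one entry per detection.\n",
    "\n",
    "# Indexing with the mask keeps only the rows where the mask is True.\n",
    "kept_boxes = boxes[valid]\n",
    "kept_labels = labels[valid]\n",
    "kept_scores = scores[valid]\n",
    "\n",
    "print(kept_labels.tolist(), kept_scores.tolist())"
   ]
  },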
  {
   "cell_type": "markdown",
   "id": "ecdcf3de-1f6b-42b5-95af-fc9de354629e",
   "metadata": {},
   "source": [
    "## Collect and sort processed frames\n",
    "After processing, collect all frames and sort them by frame number to ensure the video plays in the correct order.\n",
    "\n",
    "Note that Ray Data uses lazy execution with `map` and `map_batches` in the previous steps, meaning Ray Data performs no actions immediately. To force computation and execute the pipeline, use ds_visualized.take_all().\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5056d1a3-c728-4d3b-84d2-cc06e9bc11dc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "processed_frames = ds_visualized.take_all()\n",
    "print(\"processed_frames\", len(processed_frames))\n",
    "sorted_frames = sorted(processed_frames, key=lambda x: x[\"frame_index\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db14d87c-2153-4f41-80a1-d8b7139a812a",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Generate the output video\n",
    "Finally, generate a new video from the processed frames using OpenCV. Generate a video in `webm` format and display it in this Jupyter notebook. \n",
    "\n",
    "You can also modify the code to generate MP4 or the other formats. They work well when you play locally, but some browsers, including the Jupyter Notebook interface, which relies on the browser's video capabilities, expect the MP4 file to have the `moov` atom (metadata) at the beginning of the file to enable streaming. In many cases, the `cv2.VideoWriter` might place this metadata at the end, which doesn't affect desktop players as much but can cause issues when embedding in a browser. Formats like `webm` are often more friendly for browser playback without requiring extra post-processing steps.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45bb0622",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate output video in WebM format.\n",
    "output_video_path = \"./saved_videos/output_video.webm\"  # Save video to .webm format.\n",
    "os.makedirs(os.path.dirname(output_video_path), exist_ok=True)  # Create directory if needed.\n",
    "\n",
    "if sorted_frames:\n",
    "    # Get video properties from the first frame.\n",
    "    height, width, _ = sorted_frames[0][\"frame\"].shape\n",
    "\n",
    "    # Initialize video writer with VP8 codec for WebM.\n",
    "    fourcc = cv2.VideoWriter_fourcc(*'VP80')\n",
    "    video_writer = cv2.VideoWriter(output_video_path, fourcc, 30.0, (width, height))\n",
    "    \n",
    "    for frame in sorted_frames:\n",
    "        # Convert RGB to BGR for OpenCV.\n",
    "        bgr_frame = cv2.cvtColor(frame[\"frame\"], cv2.COLOR_RGB2BGR)\n",
    "        video_writer.write(bgr_frame)\n",
    "    \n",
    "    video_writer.release()\n",
    "    print(f\"Output video saved to: {output_video_path}\")\n",
    "else:\n",
    "    print(\"No frames available for video creation.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27345e09-6d74-476c-80f5-3a85bf07dfa5",
   "metadata": {},
   "source": [
    "## Inspect the output video\n",
    "\n",
    "You can now visualize the video within the Jupyter Notebook using the following code. Alternatively, download the video locally to verify that the object detection model rendered every frame with masks.\n",
    "\n",
    "The model performs well initially; however, as the person moves the pen in front of his face, its accuracy decreases, occasionally producing incorrect detection results.\n",
    "\n",
    "This behavior is a common challenge in object detection, especially when the model lacks sufficient training data for such scenarios. To mitigate this issue, consider collecting additional data that specifically addresses it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a02b872",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "video_path = \"saved_videos/output_video.webm\"\n",
    "\n",
    "video_html = f\"\"\"\n",
    "<video width=\"640\" height=\"360\" controls>\n",
    "  <source src=\"{video_path}\" type=\"video/mp4\">\n",
    "  Your browser does not support the video tag.\n",
    "</video>\n",
    "\"\"\"\n",
    "\n",
    "display(HTML(video_html))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
